/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.imaging.formats.png;

import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;

import org.apache.commons.imaging.ImagingException;
import org.apache.commons.imaging.PixelDensity;
import org.apache.commons.imaging.common.Allocator;
import org.apache.commons.imaging.internal.Debug;
import org.apache.commons.imaging.palette.Palette;
import org.apache.commons.imaging.palette.PaletteFactory;

/**
 * Writes {@link BufferedImage}s as PNG data streams.
 *
 * <p>
 * TODO: more filter types, proper color types (including bit depths other than 8), sRGB, etc.
 * </p>
 */
public class PngWriter {

    /*
     * The PNG specification defines four critical chunk types:
     *
     * 1. IHDR: image header, which is the first chunk in a PNG data stream.
     * 2. PLTE: palette table associated with indexed PNG images.
     * 3. IDAT: image data chunks.
     * 4. IEND: image trailer, which is the last chunk in a PNG data stream.
     *
     * The remaining 14 chunk types are termed ancillary chunk types, which encoders may generate and decoders may interpret.
     *
     * 1. Transparency information: tRNS (see 11.3.2: Transparency information).
     * 2. Color space information: cHRM, gAMA, iCCP, sBIT, sRGB (see 11.3.3: Color space information).
     * 3. Textual information: iTXt, tEXt, zTXt (see 11.3.4: Textual information).
     * 4. Miscellaneous information: bKGD, hIST, pHYs, sPLT (see 11.3.5: Miscellaneous information).
     * 5. Time information: tIME (see 11.3.6: Time stamp information).
     */

    /**
     * Immutable holder for the fields serialized into the IHDR chunk.
     */
    private static final class ImageHeader {
        public final int width;
        public final int height;
        public final byte bitDepth;
        public final PngColorType pngColorType;
        public final byte compressionMethod;
        public final byte filterMethod;
        public final InterlaceMethod interlaceMethod;

        ImageHeader(final int width, final int height, final byte bitDepth, final PngColorType pngColorType, final byte compressionMethod,
                final byte filterMethod, final InterlaceMethod interlaceMethod) {
            this.width = width;
            this.height = height;
            this.bitDepth = bitDepth;
            this.pngColorType = pngColorType;
            this.compressionMethod = compressionMethod;
            this.filterMethod = filterMethod;
            this.interlaceMethod = interlaceMethod;
        }

    }

    /**
     * Compresses {@code bytes} into a complete zlib/deflate stream using the default compression level.
     *
     * @param bytes the data to compress.
     * @return the compressed bytes.
     * @throws IOException if the deflater stream fails.
     */
    private byte[] deflate(final byte[] bytes) throws IOException {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            try (DeflaterOutputStream dos = new DeflaterOutputStream(baos)) {
                dos.write(bytes);
                // dos.flush() is not enough: the stream must be closed (finished) before baos.toByteArray()
            }
            return baos.toByteArray();
        }
    }

    /**
     * Returns the bit depth requested in {@code params} if the color type allows it, otherwise
     * {@link PngImagingParameters#DEFAULT_BIT_DEPTH}.
     */
    private byte getBitDepth(final PngColorType pngColorType, final PngImagingParameters params) {
        final byte depth = params.getBitDepth();

        return pngColorType.isBitDepthAllowed(depth) ? depth : PngImagingParameters.DEFAULT_BIT_DEPTH;
    }

    /**
     * Tests whether {@code s} survives an ISO-8859-1 encode/decode round trip, that is, whether every character is representable in ISO-8859-1.
     */
    private boolean isValidISO_8859_1(final String s) {
        final String roundtrip = new String(s.getBytes(StandardCharsets.ISO_8859_1), StandardCharsets.ISO_8859_1);
        return s.equals(roundtrip);
    }

    /**
     * Writes one chunk: a 4-byte big-endian data length, the chunk type, the data (if any) and a CRC computed over the chunk type and data.
     *
     * @param os        the stream to write to.
     * @param chunkType the chunk type.
     * @param data      the chunk data, or {@code null} for an empty chunk.
     * @throws IOException if writing fails.
     */
    private void writeChunk(final OutputStream os, final ChunkType chunkType, final byte[] data) throws IOException {
        final int dataLength = data == null ? 0 : data.length;
        writeInt(os, dataLength);
        os.write(chunkType.array);
        if (data != null) {
            os.write(data);
        }

        final PngCrc pngCrc = new PngCrc();

        final long crc1 = pngCrc.startPartialCrc(chunkType.array, chunkType.array.length);
        final long crc2 = data == null ? crc1 : pngCrc.continuePartialCrc(crc1, data, data.length);
        final int crc = (int) pngCrc.finishPartialCrc(crc2);

        writeInt(os, crc);
    }

    private void writeChunkIDAT(final OutputStream os, final byte[] bytes) throws IOException {
        writeChunk(os, ChunkType.IDAT, bytes);
    }

    private void writeChunkIEND(final OutputStream os) throws IOException {
        writeChunk(os, ChunkType.IEND, null);
    }

    /**
     * Writes the IHDR chunk: width and height as 4-byte big-endian integers, followed by one byte each for bit depth, color type, compression method,
     * filter method and interlace method.
     */
    private void writeChunkIHDR(final OutputStream os, final ImageHeader value) throws IOException {
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
        writeInt(baos, value.width);
        writeInt(baos, value.height);
        baos.write(0xff & value.bitDepth);
        baos.write(0xff & value.pngColorType.getValue());
        baos.write(0xff & value.compressionMethod);
        baos.write(0xff & value.filterMethod);
        baos.write(0xff & value.interlaceMethod.ordinal());

        writeChunk(os, ChunkType.IHDR, baos.toByteArray());
    }

    /**
     * Writes a compressed iTXt chunk. The keyword and language tag must be ISO-8859-1. The translated keyword and text are written as UTF-8, and the text
     * is deflated.
     *
     * @throws ImagingException if the keyword or language tag cannot be represented in ISO-8859-1.
     */
    private void writeChunkiTXt(final OutputStream os, final AbstractPngText.Itxt text) throws IOException, ImagingException {
        if (!isValidISO_8859_1(text.keyword)) {
            throw new ImagingException("PNG iTXt chunk keyword is not ISO-8859-1: " + text.keyword);
        }
        if (!isValidISO_8859_1(text.languageTag)) {
            throw new ImagingException("PNG iTXt chunk language tag is not ISO-8859-1: " + text.languageTag);
        }

        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        // keyword
        baos.write(text.keyword.getBytes(StandardCharsets.ISO_8859_1));
        baos.write(0);

        baos.write(1); // compressed flag, true
        baos.write(PngConstants.COMPRESSION_DEFLATE_INFLATE); // compression method

        // language tag
        baos.write(text.languageTag.getBytes(StandardCharsets.ISO_8859_1));
        baos.write(0);

        // translated keyword
        baos.write(text.translatedKeyword.getBytes(StandardCharsets.UTF_8));
        baos.write(0);

        baos.write(deflate(text.text.getBytes(StandardCharsets.UTF_8)));

        writeChunk(os, ChunkType.iTXt, baos.toByteArray());
    }

    /**
     * Writes a pHYs chunk: pixels per unit on the X and Y axes as 4-byte big-endian integers, followed by the unit specifier byte.
     */
    private void writeChunkPHYS(final OutputStream os, final int xPPU, final int yPPU, final byte units) throws IOException {
        final ByteArrayOutputStream baos = new ByteArrayOutputStream(9);
        writeInt(baos, xPPU);
        writeInt(baos, yPPU);
        baos.write(units);
        writeChunk(os, ChunkType.pHYs, baos.toByteArray());
    }

    /**
     * Writes a PLTE chunk with one RGB triple (3 bytes) per palette entry. Any alpha bits of the entries are dropped here; see
     * {@link #writeChunkTRNS(OutputStream, Palette)}.
     */
    private void writeChunkPLTE(final OutputStream os, final Palette palette) throws IOException {
        final int length = palette.length();
        final byte[] bytes = Allocator.byteArray(length * 3);

        for (int i = 0; i < length; i++) {
            final int rgb = palette.getEntry(i);
            final int index = i * 3;
            bytes[index + 0] = (byte) (0xff & rgb >> 16);
            bytes[index + 1] = (byte) (0xff & rgb >> 8);
            bytes[index + 2] = (byte) (0xff & rgb >> 0);
        }

        writeChunk(os, ChunkType.PLTE, bytes);
    }

    /**
     * Writes an sCAL chunk: the unit specifier byte, the X-axis units per pixel as text, a NUL separator, then the Y-axis units per pixel as text.
     */
    private void writeChunkSCAL(final OutputStream os, final double xUPP, final double yUPP, final byte units) throws IOException {
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        // unit specifier
        baos.write(units);

        // units per pixel, x-axis
        baos.write(String.valueOf(xUPP).getBytes(StandardCharsets.ISO_8859_1));
        baos.write(0);

        // units per pixel, y-axis (no trailing separator)
        baos.write(String.valueOf(yUPP).getBytes(StandardCharsets.ISO_8859_1));

        writeChunk(os, ChunkType.sCAL, baos.toByteArray());
    }

    /**
     * Writes an uncompressed tEXt chunk: the keyword, a NUL separator, then the text, both encoded as ISO-8859-1.
     *
     * @throws ImagingException if the keyword or text cannot be represented in ISO-8859-1.
     */
    private void writeChunktEXt(final OutputStream os, final AbstractPngText.Text text) throws IOException, ImagingException {
        if (!isValidISO_8859_1(text.keyword)) {
            throw new ImagingException("PNG tEXt chunk keyword is not ISO-8859-1: " + text.keyword);
        }
        if (!isValidISO_8859_1(text.text)) {
            throw new ImagingException("PNG tEXt chunk text is not ISO-8859-1: " + text.text);
        }

        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        // keyword
        baos.write(text.keyword.getBytes(StandardCharsets.ISO_8859_1));
        baos.write(0);

        // text
        baos.write(text.text.getBytes(StandardCharsets.ISO_8859_1));

        writeChunk(os, ChunkType.tEXt, baos.toByteArray());
    }

    /**
     * Writes a tRNS chunk containing the alpha byte (bits 24-31) of each palette entry, in palette order.
     */
    private void writeChunkTRNS(final OutputStream os, final Palette palette) throws IOException {
        final byte[] bytes = Allocator.byteArray(palette.length());

        for (int i = 0; i < bytes.length; i++) {
            bytes[i] = (byte) (0xff & palette.getEntry(i) >> 24);
        }

        writeChunk(os, ChunkType.tRNS, bytes);
    }

    /**
     * Writes the XMP packet as a compressed iTXt chunk. The keyword is {@link PngConstants#XMP_KEYWORD}, the language tag is empty and the translated
     * keyword repeats the keyword.
     */
    private void writeChunkXmpiTXt(final OutputStream os, final String xmpXml) throws IOException {

        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        // keyword
        baos.write(PngConstants.XMP_KEYWORD.getBytes(StandardCharsets.ISO_8859_1));
        baos.write(0);

        baos.write(1); // compressed flag, true
        baos.write(PngConstants.COMPRESSION_DEFLATE_INFLATE); // compression method

        baos.write(0); // language tag (empty). TODO

        // translated keyword
        baos.write(PngConstants.XMP_KEYWORD.getBytes(StandardCharsets.UTF_8));
        baos.write(0);

        baos.write(deflate(xmpXml.getBytes(StandardCharsets.UTF_8)));

        writeChunk(os, ChunkType.iTXt, baos.toByteArray());
    }

    /**
     * Writes a zTXt chunk: the ISO-8859-1 keyword, a NUL separator, the compression method byte, then the deflated ISO-8859-1 text.
     *
     * @throws ImagingException if the keyword or text cannot be represented in ISO-8859-1.
     */
    private void writeChunkzTXt(final OutputStream os, final AbstractPngText.Ztxt text) throws IOException, ImagingException {
        if (!isValidISO_8859_1(text.keyword)) {
            throw new ImagingException("PNG zTXt chunk keyword is not ISO-8859-1: " + text.keyword);
        }
        if (!isValidISO_8859_1(text.text)) {
            throw new ImagingException("PNG zTXt chunk text is not ISO-8859-1: " + text.text);
        }

        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        // keyword
        baos.write(text.keyword.getBytes(StandardCharsets.ISO_8859_1));
        baos.write(0);

        // compression method
        baos.write(PngConstants.COMPRESSION_DEFLATE_INFLATE);

        // text
        baos.write(deflate(text.text.getBytes(StandardCharsets.ISO_8859_1)));

        writeChunk(os, ChunkType.zTXt, baos.toByteArray());
    }

    /*
     * Chunk ordering rules (PNG specification, Table 5.3).
     *
     * Critical chunks (shall appear in this order, except PLTE is optional):
     *
     *   Chunk  Multiple allowed  Ordering constraints
     *   IHDR   No                Shall be first
     *   PLTE   No                Before first IDAT
     *   IDAT   Yes               Multiple IDAT chunks shall be consecutive
     *   IEND   No                Shall be last
     *
     * Ancillary chunks (need not appear in this order):
     *
     *   Chunk  Multiple allowed  Ordering constraints
     *   cHRM   No                Before PLTE and IDAT
     *   gAMA   No                Before PLTE and IDAT
     *   iCCP   No                Before PLTE and IDAT; if iCCP is present, sRGB should not be present
     *   sBIT   No                Before PLTE and IDAT
     *   sRGB   No                Before PLTE and IDAT; if sRGB is present, iCCP should not be present
     *   bKGD   No                After PLTE; before IDAT
     *   hIST   No                After PLTE; before IDAT
     *   tRNS   No                After PLTE; before IDAT
     *   pHYs   No                Before IDAT
     *   sCAL   No                Before IDAT
     *   sPLT   Yes               Before IDAT
     *   tIME   No                None
     *   iTXt   Yes               None
     *   tEXt   Yes               None
     *   zTXt   Yes               None
     */

    /**
     * Writes an image to an output stream.
     *
     * <p>
     * The output stream is closed after the IEND chunk has been written successfully.
     * </p>
     *
     * @param src            The image to write.
     * @param os             The output stream to write to.
     * @param params         The parameters to use (can be {@code null} to use the default {@link PngImagingParameters}).
     * @param paletteFactory The palette factory to use (can be {@code null} to use the default {@link PaletteFactory}).
     * @throws ImagingException When errors are detected, for example when both indexed and true color modes are forced, or a text chunk contains
     *                          characters its encoding cannot represent.
     * @throws IOException      When IO problems occur.
     */
    public void writeImage(final BufferedImage src, final OutputStream os, PngImagingParameters params, PaletteFactory paletteFactory)
            throws ImagingException, IOException {
        if (params == null) {
            params = new PngImagingParameters();
        }
        if (paletteFactory == null) {
            paletteFactory = new PaletteFactory();
        }
        final int compressionLevel = Deflater.DEFAULT_COMPRESSION;

        final int width = src.getWidth();
        final int height = src.getHeight();

        final boolean hasAlpha = paletteFactory.hasTransparency(src);
        Debug.debug("hasAlpha: " + hasAlpha);

        boolean isGrayscale = paletteFactory.isGrayscale(src);
        Debug.debug("isGrayscale: " + isGrayscale);

        final boolean forceIndexedColor = params.isForceIndexedColor();
        final boolean forceTrueColor = params.isForceTrueColor();
        if (forceIndexedColor && forceTrueColor) {
            throw new ImagingException("Params: Cannot force both indexed and true color modes");
        }
        final PngColorType pngColorType;
        if (forceIndexedColor) {
            pngColorType = PngColorType.INDEXED_COLOR;
        } else if (forceTrueColor) {
            pngColorType = hasAlpha ? PngColorType.TRUE_COLOR_WITH_ALPHA : PngColorType.TRUE_COLOR;
            // Forcing true color also disables the gray-sample conversion in the scanline encoder.
            isGrayscale = false;
        } else {
            pngColorType = PngColorType.getColorType(hasAlpha, isGrayscale);
        }
        Debug.debug("colorType: " + pngColorType);

        // NOTE(review): the scanline encoders below always emit 8-bit samples (and one byte per palette index), so a non-8
        // bit depth accepted by getBitDepth() produces an IHDR that does not match the pixel data. Confirm whether other
        // depths should be rejected or implemented.
        final byte bitDepth = getBitDepth(pngColorType, params);
        Debug.debug("bitDepth: " + bitDepth);

        final int sampleDepth = pngColorType == PngColorType.INDEXED_COLOR ? 8 : bitDepth;
        Debug.debug("sampleDepth: " + sampleDepth);

        PngConstants.PNG_SIGNATURE.writeTo(os);

        // IHDR must be the first chunk.
        final byte compressionMethod = PngConstants.COMPRESSION_TYPE_INFLATE_DEFLATE;
        final byte filterMethod = PngConstants.FILTER_METHOD_ADAPTIVE;
        final InterlaceMethod interlaceMethod = InterlaceMethod.NONE;
        writeChunkIHDR(os, new ImageHeader(width, height, bitDepth, pngColorType, compressionMethod, filterMethod, interlaceMethod));

        // TODO: sRGB (before PLTE and IDAT; if sRGB is present, iCCP should not be present).

        Palette palette = null;
        if (pngColorType == PngColorType.INDEXED_COLOR) {
            // PLTE must precede the first IDAT; tRNS must follow PLTE.
            final int maxColors = 256;

            if (hasAlpha) {
                palette = paletteFactory.makeQuantizedRgbaPalette(src, hasAlpha, maxColors);
                writeChunkPLTE(os, palette);
                writeChunkTRNS(os, palette);
            } else {
                palette = paletteFactory.makeQuantizedRgbPalette(src, maxColors);
                writeChunkPLTE(os, palette);
            }
        }

        final Object pixelDensityObj = params.getPixelDensity();
        if (pixelDensityObj != null) {
            final PixelDensity pixelDensity = (PixelDensity) pixelDensityObj;
            if (pixelDensity.isUnitless()) {
                writeChunkPHYS(os, (int) Math.round(pixelDensity.getRawHorizontalDensity()), (int) Math.round(pixelDensity.getRawVerticalDensity()), (byte) 0);
            } else {
                writeChunkPHYS(os, (int) Math.round(pixelDensity.horizontalDensityMetres()), (int) Math.round(pixelDensity.verticalDensityMetres()), (byte) 1);
            }
        }

        final PhysicalScale physicalScale = params.getPhysicalScale();
        if (physicalScale != null) {
            writeChunkSCAL(os, physicalScale.getHorizontalUnitsPerPixel(), physicalScale.getVerticalUnitsPerPixel(),
                    physicalScale.isInMeters() ? (byte) 1 : (byte) 2);
        }

        final String xmpXml = params.getXmpXml();
        if (xmpXml != null) {
            writeChunkXmpiTXt(os, xmpXml);
        }

        writeTextChunks(os, params.getTextChunks());

        // 28 March 2022. At this time, we only apply the predictor
        // for non-grayscale, true-color images. This choice is made
        // out of caution and is not necessarily required by the PNG
        // spec. We may broaden the use of predictors in future versions.
        final boolean usePredictor = params.isPredictorEnabled() && !isGrayscale && palette == null;
        final boolean useAlpha = pngColorType == PngColorType.GREYSCALE_WITH_ALPHA || pngColorType == PngColorType.TRUE_COLOR_WITH_ALPHA;

        final byte[] uncompressed = usePredictor
                ? encodeSubFilteredScanlines(src, useAlpha)
                : encodeUnfilteredScanlines(src, palette, isGrayscale, useAlpha);

        // Multiple IDAT chunks must be consecutive.
        writeImageData(os, uncompressed, compressionLevel);

        // IEND must be the last chunk.
        writeChunkIEND(os);

        os.close();
    }

    /**
     * Writes each text chunk in order, dispatching on its concrete type (tEXt, zTXt or iTXt).
     *
     * @param os    the stream to write to.
     * @param texts the text chunks to write; {@code null} writes nothing.
     * @throws ImagingException if a text has an unknown type or fails its chunk's encoding checks.
     * @throws IOException      if writing fails.
     */
    private void writeTextChunks(final OutputStream os, final List<? extends AbstractPngText> texts) throws IOException, ImagingException {
        if (texts == null) {
            return;
        }
        for (final AbstractPngText text : texts) {
            if (text instanceof AbstractPngText.Text) {
                writeChunktEXt(os, (AbstractPngText.Text) text);
            } else if (text instanceof AbstractPngText.Ztxt) {
                writeChunkzTXt(os, (AbstractPngText.Ztxt) text);
            } else if (text instanceof AbstractPngText.Itxt) {
                writeChunkiTXt(os, (AbstractPngText.Itxt) text);
            } else {
                throw new ImagingException("Unknown text to embed in PNG: " + text);
            }
        }
    }

    /**
     * Serializes every row of {@code src} with filter type NONE.
     *
     * <p>
     * Each row starts with the filter-type byte. If {@code palette} is non-null, one palette-index byte per pixel follows. Otherwise each pixel is one
     * gray byte (integer mean of R, G and B) when {@code isGrayscale}, or three R, G, B bytes, plus a trailing alpha byte when {@code useAlpha}.
     * </p>
     */
    private byte[] encodeUnfilteredScanlines(final BufferedImage src, final Palette palette, final boolean isGrayscale, final boolean useAlpha)
            throws ImagingException {
        final int width = src.getWidth();
        final int height = src.getHeight();
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final int[] row = Allocator.intArray(width);
        for (int y = 0; y < height; y++) {
            src.getRGB(0, y, width, 1, row, 0, width);

            baos.write(FilterType.NONE.ordinal());
            for (int x = 0; x < width; x++) {
                final int argb = row[x];

                if (palette != null) {
                    final int index = palette.getPaletteIndex(argb);
                    baos.write(0xff & index);
                    continue;
                }

                final int alpha = 0xff & argb >> 24;
                final int red = 0xff & argb >> 16;
                final int green = 0xff & argb >> 8;
                final int blue = 0xff & argb >> 0;

                if (isGrayscale) {
                    baos.write((red + green + blue) / 3);
                } else {
                    baos.write(red);
                    baos.write(green);
                    baos.write(blue);
                }
                if (useAlpha) {
                    baos.write(alpha);
                }
            }
        }
        return baos.toByteArray();
    }

    /**
     * Serializes every row of {@code src} as RGB (or RGBA when {@code useAlpha}) samples using the Sub filter.
     *
     * <p>
     * Each sample byte is stored as the difference from the same sample of the pixel to its left (0 for the first pixel in a row), truncated to 8 bits.
     * </p>
     */
    private byte[] encodeSubFilteredScanlines(final BufferedImage src, final boolean useAlpha) throws ImagingException {
        final int width = src.getWidth();
        final int height = src.getHeight();
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final int[] row = Allocator.intArray(width);
        for (int y = 0; y < height; y++) {
            src.getRGB(0, y, width, 1, row, 0, width);

            int priorA = 0;
            int priorR = 0;
            int priorG = 0;
            int priorB = 0;
            baos.write(FilterType.SUB.ordinal());
            for (int x = 0; x < width; x++) {
                final int argb = row[x];
                final int alpha = 0xff & argb >> 24;
                final int red = 0xff & argb >> 16;
                final int green = 0xff & argb >> 8;
                final int blue = 0xff & argb;

                // ByteArrayOutputStream.write(int) keeps only the low 8 bits, giving the modulo-256 difference.
                baos.write(red - priorR);
                baos.write(green - priorG);
                baos.write(blue - priorB);
                priorR = red;
                priorG = green;
                priorB = blue;

                if (useAlpha) {
                    baos.write(alpha - priorA);
                    priorA = alpha;
                }
            }
        }
        return baos.toByteArray();
    }

    /**
     * Deflates the filtered scanline data and writes it as one or more consecutive IDAT chunks. Input is fed in 256 KiB slices. An IDAT chunk is written
     * whenever the deflater has produced output, and once more after the stream is finished.
     *
     * @param os               the stream to write to.
     * @param uncompressed     the filtered scanline data.
     * @param compressionLevel the {@link Deflater} compression level.
     * @throws IOException if writing fails.
     */
    private void writeImageData(final OutputStream os, final byte[] uncompressed, final int compressionLevel) throws IOException {
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
        final int chunkSize = 256 * 1024;
        final Deflater deflater = new Deflater(compressionLevel);
        try (DeflaterOutputStream dos = new DeflaterOutputStream(baos, deflater, chunkSize)) {
            for (int index = 0; index < uncompressed.length; index += chunkSize) {
                // Computed as a remaining count so that "index + chunkSize" cannot overflow for very large inputs.
                final int length = Math.min(uncompressed.length - index, chunkSize);

                dos.write(uncompressed, index, length);
                dos.flush();

                final byte[] compressed = baos.toByteArray();
                baos.reset();
                if (compressed.length > 0) {
                    writeChunkIDAT(os, compressed);
                }
            }

            dos.finish();
            final byte[] compressed = baos.toByteArray();
            if (compressed.length > 0) {
                writeChunkIDAT(os, compressed);
            }
        } finally {
            // DeflaterOutputStream only ends Deflaters it created itself; release this one's native zlib memory explicitly.
            deflater.end();
        }
    }

    /**
     * Writes {@code value} as 4 bytes in big-endian (network) order.
     */
    private void writeInt(final OutputStream os, final int value) throws IOException {
        os.write(0xff & value >> 24);
        os.write(0xff & value >> 16);
        os.write(0xff & value >> 8);
        os.write(0xff & value >> 0);
    }
}
